
"""

AUTHOR: Marco Marchetti, Bruce A. Edgar Lab, Huntsman Cancer Institute, Salt Lake City, 84112, UT, USA

DESCRIPTION: This script performs local image registration of adult Drosophila midgut fluorescence imaging time-lapses.
Movies are converted to 8 bit, normalized by subtracting the mean and dividing by the standard deviation, converted to maximum-intensity
or focused projections, and then divided into partially overlapping regions of interest (ROIs). Each ROI is XY-registered via cross-correlation
between frames, after which its most in-focus Z slice is found. ROIs are finally exported.

"""

def displayOptions():
    """Print the menu of operating modes the user can pick by number."""
    header = "\nChoose one of the following options by typing the corresponding number:\n"
    entries = (
        " 1 - Load an image",
        " 2 - Find ROIs and register them",
        " 3 - Register a specified position",
        " 4 - Import ROIs coordinates",
        " 5 - Export ROIs",
        " 6 - Quit",
    )
    print(header)
    for entry in entries:
        print(entry)

def movieImporter():
    """Interactively load a TIFF movie and reshape it to (frames, z, channels, y, x).

    The user types ``FileName Frames Z-slices Channels``. The prompt repeats until the
    file can be read and reshaped to the given dimensions.

    Returns:
        (mov, file): the reshaped numpy array and the file name as typed.

    Raises:
        EOFError: if standard input is exhausted, instead of re-prompting forever.
    """
    # Import parameters for the movie to be registered
    print("Please enter the following parameters in order, separated by a space:")
    print("FileName Frames Z-slices Channels")
    print("e.g. Test_Image.tif 49 10 2")
    print("N.B. File name MUST NOT contain spaces")
    while True:  # Checking validity of input
        try:
            # rsplit takes the last three fields as the dimensions and tolerates
            # repeated or trailing whitespace
            file, f, z, c = input().rsplit(maxsplit=3)
            mov = tifffile.imread(file)
            shape = (int(f), int(z), int(c), mov.shape[-2], mov.shape[-1])
            mov = mov.reshape(shape)
            break
        except EOFError:
            # Nothing more to read: re-prompting would loop forever
            raise
        except Exception:
            # Wrong field count, unreadable file, non-integer or mismatched dimensions.
            # KeyboardInterrupt is not an Exception, so Ctrl+C still stops the script.
            print("\nOne or more parameters is not valid, please insert them again...")
    print("\nImage has been successfully imported")

    return mov, file

def otsuThreshold(profile):
    """Return the Otsu threshold of 8-bit image data.

    The threshold t maximises the between-class variance between pixels < t
    (background) and pixels >= t (foreground). Every integer in (min, max] is a
    candidate, and ties resolve to the lowest threshold.

    Args:
        profile: array of integer pixel values in [0, 255], of any shape (it is
            flattened).

    Returns:
        int: the best threshold, in pixel-value units. For a constant input no split
        exists, and that single pixel value is returned.
    """
    # One bin per 8-bit value, so that a bin index *is* a pixel value. Without an
    # explicit range numpy spans [min, max], and bin indices stop matching pixel values.
    counts = numpy.histogram(profile, bins=256, range=(0, 256))[0].astype(float)
    values = numpy.arange(256)

    # Entry t - 1 of these cumulative arrays covers every pixel value below t
    cumulative_counts = numpy.cumsum(counts)
    cumulative_sums = numpy.cumsum(counts * values)
    total_count, total_sum = cumulative_counts[-1], cumulative_sums[-1]

    low, high = int(numpy.min(profile)), int(numpy.max(profile))
    if high <= low:
        return low
    candidates = numpy.arange(low + 1, high + 1)

    # Both classes are non-empty for every candidate: the background always holds the
    # `low` pixels and the foreground always holds the `high` pixels
    background_count = cumulative_counts[candidates - 1]
    foreground_count = total_count - background_count
    background_mean = cumulative_sums[candidates - 1] / background_count
    foreground_mean = (total_sum - cumulative_sums[candidates - 1]) / foreground_count
    between_class_variance = (
        (background_count / total_count)
        * (foreground_count / total_count)
        * (background_mean - foreground_mean) ** 2
    )

    # argmax returns the first (lowest) threshold among equally good ones
    return int(candidates[numpy.argmax(between_class_variance)])

def makeProjection(mov, mode):
    """Collapse the z axis of every frame of ``mov`` into one 2D image.

    Args:
        mov: sequence of per-frame stacks indexed as [z, y, x].
        mode: "max" selects a maximum-intensity projection. Any other value selects a
            focused projection, where each pixel is taken from the z slice chosen by
            getFocusMap for that frame.

    Returns:
        numpy array of shape (frames, y, x). The focused projection is float64.
    """
    if mode != "max":
        print("Creating movie focused z-projection")
        focused = []
        for stack in mov:
            focus_map = getFocusMap(stack)
            # For every (y, x), pick the value at z = focus_map[y, x]
            picked = numpy.take_along_axis(stack, focus_map[numpy.newaxis], axis=0)[0]
            focused.append(picked.astype(float))
        return numpy.array(focused)

    print("Creating movie maximum intensity z-projection")
    return numpy.array([numpy.max(stack, 0) for stack in mov])

def getFocusMap(stack):
    """Return, for every (y, x) pixel, the index of the sharpest z slice of ``stack``.

    Sharpness is the squared Sobel gradient magnitude of the (1, 3, 3) median-filtered
    stack. The outermost rows and columns are zeroed, and the result is spread by a
    (1, 3, 3) maximum filter before taking the argmax along z.

    Args:
        stack: 3D array indexed as [z, y, x].

    Returns:
        Integer array of shape (y, x) holding z indices (0 where no slice has edges).
    """
    # Horizontal Sobel kernel; the vertical one is its transpose in the (y, x) plane.
    # The leading length-1 axis keeps every z slice independent.
    kernel_x = numpy.array([[[1, 0, -1], [2, 0, -2], [1, 0, -1]]])
    kernel_y = kernel_x.transpose(0, 2, 1)

    smoothed = medfil(stack, (1, 3, 3), mode="reflect").astype(float)
    gradient_x = correlate(smoothed, kernel_x, mode="reflect")
    gradient_y = correlate(smoothed, kernel_y, mode="reflect")
    sharpness = gradient_x ** 2 + gradient_y ** 2

    # Blank the outer rows and columns to suppress artifacts at the image borders
    sharpness[:, [0, -1], :] = 0
    sharpness[:, :, [0, -1]] = 0

    sharpness = maxfil(sharpness, (1, 3, 3), mode="reflect")
    return numpy.argmax(sharpness, 0)

def subdivideFrame(mask, w, f):
    """Place candidate ROI centres on the foreground of a frame mask.

    The mask is first eroded with a 3x3 minimum filter. The foreground bounding box,
    padded by 20 px, is then tiled with windows starting every w/2 px. Every window
    holding at least 5% of w**2 foreground pixels contributes the (int-truncated) mean
    position of its foreground pixels. Centres closer than w/4 to another centre are
    then repeatedly replaced by the (int-truncated) mean of their neighbourhood, until
    a pass merges nothing.

    Args:
        mask: 2D array; pixels comparing equal to 1 (True included) are foreground.
        w: ROI window size in pixels.
        f: frame index attached to every returned centre.

    Returns:
        A list with one entry per centre, each a one-element list ``[(f, y, x)]`` so
        that it can later be extended with positions from other frames. It is empty,
        for example, when no foreground survives the erosion. y and x may be NumPy
        integer scalars for centres that were not merged.
    """

    # 3x3 minimum filter (erosion): removes isolated foreground pixels before tiling
    mask = minfil(mask, (3, 3))

    # Defining the minimum number of pixels above threshold (5% of total) that a ROI must have to be deemed interesting
    min_interesting_pixels = w**2 * 5 / 100

    # Finding local centers of mass
    y_pos, x_pos = numpy.where(mask == 1)
    if len(y_pos) == 0:
        return []
    # Tiling bounds: foreground bounding box padded by 20 px, with the stop pulled in by w/2.
    # NOTE(review): if the padded extent is shorter than w/2 the ranges below are empty and
    # no centre is produced even though foreground exists — confirm this is intended.
    y_start, y_stop = max(0, min(y_pos) - 20), min(mask.shape[0] - 1, max(y_pos) + 20) - int(w/2) # -+20 to give a buffer
    x_start, x_stop = max(0, min(x_pos) - 20), min(mask.shape[1] - 1, max(x_pos) + 20) - int(w/2) # -+20 to give a buffer
    centers = []
    for y0 in numpy.arange(y_start, y_stop, w/2).astype(int):
        # NOTE(review): the last window is cut at y_stop, so it lies inside the previous
        # window, and up to w/2 rows at the far end of the bounding box appear to stay
        # untiled. With odd w, w/2 is fractional and this equality may never hold.
        # Confirm the intended window end.
        if y0 == y_stop - (y_stop - y_start) % (w/2):
            y1 = y_stop
        else:
            y1 = y0 + w
        for x0 in numpy.arange(x_start, x_stop, w/2).astype(int):
            # Same last-window handling as for rows (see NOTE above)
            if x0 == x_stop - (x_stop - x_start) % (w/2):
                x1 = x_stop
            else:
                x1 = x0 + w
            cutout = mask[y0 : y1, x0 : x1]
            if numpy.sum(cutout == 1) < min_interesting_pixels: # Too few interesting pixels to care (i.e. < 5%)
                continue
            # Mean foreground position inside the window, shifted back to frame coordinates
            center_of_mass = tuple([int(numpy.mean(pos)) + start for pos,start in zip(numpy.where(cutout == 1), (y0, x0))])
            centers.append(center_of_mass)

    # Merging centers too close to each other (i.e. having distance < w / 4), repeating until a pass merges nothing
    distance = lambda a,b: ((a[0] - b[0])**2 + (a[1] - b[1])**2)**0.5
    toggle = 1
    while toggle == 1:
        cleaned, toggle = [], 0
        for a in centers:
            close_centers = [b for b in centers if (b != a) and (distance(a, b) < w / 4)] + [a]
            if len(close_centers) == 1:
                cleaned.append(a)
                continue
            # Replace the neighbourhood by its int-truncated mean position
            new_center = tuple([int(sum([c[i] for c in close_centers]) / len(close_centers)) for i in range(2)])
            if new_center in cleaned: # Point was already defined
                continue
            cleaned.append(new_center)
            toggle = 1
        centers = cleaned.copy()

    # Adding frame information and wrapping each centre in a list, so later frames can be appended
    centers = [[(f, c[0], c[1])] for c in centers]

    return centers

def compareTimePoints(f1, f2, coords, t, w, d, correl, f_start, f_stop, new_roi_search):
    """Propagate ROI positions from frame f_start to frame f_stop by template matching.

    For every ROI that has a position at f_start but none yet at f_stop, a cutout of
    f1 of up to w x w pixels centred on that position is matched (skimage
    ``match_template``) against f2 inside the same window enlarged by d pixels on each
    side. If the best score reaches ``correl``, the matched centre is appended to the
    ROI as ``(f_stop, y, x)`` with plain Python ints.

    Args:
        f1, f2: 2D projections of frames f_start and f_stop.
        coords: list of ROIs. Each ROI is a list of (frame, y, x) tuples, possibly
            followed by a "ROI Lost" marker. Modified in place.
        t: foreground threshold, used only when searching for new ROIs.
        w: ROI window size in pixels.
        d: extra search margin (maximum displacement), in pixels.
        correl: minimum template-matching score for a match to be accepted.
        f_start, f_stop: frame indices of f1 and f2.
        new_roi_search: when truthy, unmatched ROIs are marked "ROI Lost". Foreground
            of f2 (>= t) not covered by a w x w window around a ROI found at f_stop is
            then passed to subdivideFrame to start new ROIs.

    Returns:
        coords (the same list object).
    """
    half = int(w / 2)

    # Tracking ROIs
    for roi in coords:

        # Skipping ROIs already registered at f_stop or with no position at f_start
        frames = [entry[0] for entry in roi]
        if f_stop in frames:  # Image was already registered to this time-point
            continue
        if f_start not in frames:  # ROI was probably lost
            continue

        # Cutout of f1 around the ROI, cross-correlated with a search area of f2 enlarged by d
        c = [entry for entry in roi if entry[0] == f_start][0][1:]
        y_slice0 = slice(max(0, c[0] - half), min(f1.shape[0], c[0] + half), None)
        x_slice0 = slice(max(0, c[1] - half), min(f1.shape[1], c[1] + half), None)
        cutout = f1[y_slice0, x_slice0]
        y_slice1 = slice(max(0, c[0] - half - d), min(f1.shape[0], c[0] + half + d), None)
        x_slice1 = slice(max(0, c[1] - half - d), min(f1.shape[1], c[1] + half + d), None)
        search_area = f2[y_slice1, x_slice1]
        cross_correl = match_template(search_area, cutout, pad_input = True, mode = "constant", constant_values = 0.)
        best_score = numpy.max(cross_correl)
        if best_score < correl:
            if new_roi_search:  # Only mark a ROI as lost if moving forward in time
                roi.append("ROI Lost")
            continue
        # First (row-major) position of the best score
        y_pos, x_pos = [indices[0] for indices in numpy.where(cross_correl == best_score)]

        # Adjusting for the offsets introduced by slicing (useful near the image edges).
        # int() keeps plain Python ints in coords, so they serialise as "5" rather than
        # NumPy 2's "np.int64(5)"
        y_pos = int(y_pos + y_slice1.start - int(cutout.shape[0] / 2 - (c[0] - y_slice0.start)))
        x_pos = int(x_pos + x_slice1.start - int(cutout.shape[1] / 2 - (c[1] - x_slice0.start)))
        roi.append((f_stop, y_pos, x_pos))

    # Finding new ROIs if the coverage of frame f2 by existing ROIs is incomplete
    if new_roi_search:
        coverage = numpy.zeros(f2.shape)
        for roi in coords:
            if roi[-1][0] == f_stop:
                coverage[roi[-1][1], roi[-1][2]] = 1
        coverage = maxfil(coverage, (w, w))
        coverage = (f2 >= t) * (1 - coverage)
        coords.extend(subdivideFrame(coverage, w, f_stop))

    return coords

def refineCoordinates(mov, coords, w):
    """Attach the sharpest z slice to each (frame, y, x) position of one ROI.

    Picking the most in-focus slice as the centre is faster than cross-correlating
    along z, and works well enough.

    Args:
        mov: per-frame stacks indexed as mov[frame][z, y, x].
        coords: the ROI's (frame, y, x) positions, possibly ending with "ROI Lost".
        w: ROI window size in pixels. Focus is scored with getFocus on the (up to)
            w x w cutout around (y, x) of every z slice.

    Returns:
        List of (frame, z, y, x) tuples of plain Python ints. Processing stops at
        "ROI Lost".
    """
    half = int(w / 2)
    refined_c = []
    for c in coords:
        if c == "ROI Lost":
            break
        # Plain ints, so exported tuples never show NumPy 2 reprs such as "np.int64(5)"
        frame, y, x = int(c[0]), int(c[1]), int(c[2])
        cutout = mov[frame][:, max(0, y - half) : y + half, max(0, x - half) : x + half]
        most_focused_z = int(numpy.argmax([getFocus(pic) for pic in cutout]))
        refined_c.append((frame, most_focused_z, y, x))

    return refined_c

def getFocus(pic):
    """Return a sharpness score for a 2D image: its summed squared Sobel gradient.

    The image is 3x3 median filtered first, and the outermost rows and columns of the
    gradient are ignored to avoid border artifacts.
    """
    kernel_x = numpy.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
    kernel_y = kernel_x.T  # vertical Sobel kernel

    smoothed = medfil(pic, (3, 3), mode="reflect").astype(float)
    sharpness = (
        correlate(smoothed, kernel_x, mode="reflect") ** 2
        + correlate(smoothed, kernel_y, mode="reflect") ** 2
    )

    # Zeroing (rather than slicing away) the borders keeps the summation identical
    sharpness[[0, -1], :] = 0
    sharpness[:, [0, -1]] = 0

    return numpy.sum(sharpness)

def exportCoordinates(file, coords):
    """Save ROI coordinates to ``<file without its last 4 characters>_ROIs.tsv``.

    Each line reads ``ROI-<n>`` followed by tab-separated coordinate tuples written as
    ``(a, b, ...)`` with plain integers. A "ROI Lost" marker is written verbatim.
    Lines are joined with newlines, with no trailing newline. The file can be imported
    later to extract ROIs.

    Args:
        file: movie file name, expected to end with a 4-character extension such as ".tif".
        coords: one list per ROI of integer tuples (NumPy integer scalars accepted)
            and/or "ROI Lost".
    """
    def format_entry(entry):
        if isinstance(entry, str):  # "ROI Lost" marker
            return entry
        # int() strips NumPy scalar types, whose repr is "np.int64(5)" on NumPy >= 2
        # and would make the file unreadable on re-import
        return "(" + ", ".join(str(int(value)) for value in entry) + ")"

    rows = [
        "\t".join(["ROI-" + str(n)] + [format_entry(entry) for entry in roi])
        for n, roi in enumerate(coords, start=1)
    ]
    with open(file[:-4] + "_ROIs.tsv", "w") as output:
        output.write("\n".join(rows))

def inputPointsOfInterest():
    """Ask once for the registration parameters, then for one or more points of interest.

    Returns:
        (rc, ws, md, mc, coordinates, projection_mode), where:
        - rc is the 0-based registration channel;
        - ws is the window size (px);
        - md is the maximum search distance (px);
        - mc is the minimum correlation score;
        - coordinates is a list of ``[(t, y, x)]`` entries with 0-based values;
        - projection_mode is "focus" or "max".

    Raises:
        EOFError: if standard input is exhausted, instead of re-prompting forever.
    """
    rc, ws, md, mc, projection_mode = _promptRegistrationParameters()
    coordinates = []
    while True:
        coordinates.append([_promptPointOfInterest()])

        # User decides whether to input an additional point or not
        print("\nAdd another point? Type 'y' if so, any other key otherwise")
        if input().lower() != "y":
            break

    return rc, ws, md, mc, coordinates, projection_mode


def _promptRegistrationParameters():
    """Read registration parameters until valid; return (rc, ws, md, mc, projection_mode)."""
    print("Please enter the following parameters in order, separated by a space:")
    print("RegistrationChannel WindowSize_(px) MaxDistance_(px) MinCorrelationScore_(0-1) ProjectionMode_(Focus_or_Max)")
    print("e.g. 1 256 500 0.5 Focus")
    while True:
        try:
            rc, ws, md, mc, projection_mode = input().lower().split()
            rc, ws, md, mc = int(rc) - 1, int(ws), int(md), float(mc)
        except ValueError:  # wrong number of fields or non-numeric values
            pass
        else:
            # Channels are 1-based on input (0 would silently select the last channel via
            # index -1). A window needs a positive size; the search distance cannot be negative.
            if projection_mode in ("focus", "max") and rc >= 0 and ws > 0 and md >= 0:
                return rc, ws, md, mc, projection_mode
        print("\nOne or more parameters is not valid, please insert them again...")


def _promptPointOfInterest():
    """Read a 1-based "X Y T" point until valid and return it as a 0-based (t, y, x) tuple."""
    print("\nPlease enter the coordinates of the point of interest, separated by a space:")
    print("X Y T")
    print("e.g. 32 47 2")
    print("N.B. X and Y coordinates are defined in pixels")
    while True:
        try:
            x, y, t = [int(value) for value in input().split()]
        except ValueError:  # wrong number of fields or non-integer values
            pass
        else:
            # Values are 1-based; 0 or less would become negative indices
            if min(x, y, t) >= 1:
                return t - 1, y - 1, x - 1
        print("\nOne or more parameters is not valid, please insert them again...")

def importCoordinates():
    """Ask for a ROI coordinates .tsv file and parse it into per-ROI coordinate lists.

    Every non-empty line is a ROI label followed by tab-separated entries. Each entry
    is a tuple such as ``(0, 3, 120, 340)`` or the marker "ROI Lost". Double quotes
    (added when the file is edited by hand or in a spreadsheet) and empty cells are
    ignored. The file name is asked for again until the file opens and parses.

    Returns:
        list with one list per ROI, holding tuples of ints and/or "ROI Lost".

    Raises:
        EOFError: if standard input is exhausted, instead of asking forever.
    """
    # Import parameters for the ROI coordinates .tsv file
    print("Please enter the ROIs coordinates file name:")
    print("e.g. Sample_ROIs.tsv")
    print("N.B. File name MUST NOT contain spaces")
    while True:  # Checking validity of input
        file = input()
        try:
            with open(file) as handle:
                lines = handle.read().splitlines()
            parsed = []
            for line in lines:
                if not line.strip():
                    continue
                # Quotes are removed in case the ROI file was manually modified
                cells = line.replace('"', '').split("\t")[1:]
                parsed.append([_parseCoordinate(cell) for cell in cells if cell.strip()])
        except (OSError, ValueError):  # missing/unreadable file or malformed entry
            print("\nThe file name is incorrect, please insert it again...")
        else:
            break
    print("\nROIs coordinates have been successfully imported")

    return parsed


def _parseCoordinate(cell):
    """Turn one "(a, b, ...)" cell into a tuple of ints; "ROI Lost" is returned as is.

    Raises:
        ValueError: if the cell is neither form.
    """
    import re

    cell = cell.strip()
    if cell == "ROI Lost":
        return cell
    # Files written under NumPy >= 2 may contain scalar reprs such as "np.int64(5)"
    cell = re.sub(r"np\.\w+\((-?\d+)\)", r"\1", cell)
    if not (cell.startswith("(") and cell.endswith(")")):
        raise ValueError("Malformed ROI coordinate: " + cell)
    return tuple(int(value) for value in cell[1:-1].split(","))

def exportROIs(mov, coords, file):
    """Cut every registered ROI out of the movie and save it as an ImageJ TIFF.

    Asks for ``WindowSize StacksBeforeCenter StacksAfterCenter`` (ws, sb, sa). For every
    frame in which a ROI has a (frame, z, y, x) coordinate, the block is copied with:
    - the z slices z - sb .. z + sa;
    - all channels;
    - a ws x ws window around (y, x).
    Parts falling outside the movie are zero-filled. Each ROI is written to
    ``<file minus its last 4 characters>_ROIs/ROI-<n>_T=<first>-<last>.tif``, where
    <first> and <last> are the 1-based extreme frames of its coordinates. ROIs with no
    coordinate inside the movie are skipped.

    Args:
        mov: movie array indexed as [frame, z, channel, y, x].
        coords: one list per ROI of (frame, z, y, x) tuples; "ROI Lost" markers are ignored.
        file: movie file name, used to derive the output folder.

    Raises:
        EOFError: if standard input is exhausted, instead of re-prompting forever.
    """
    from os.path import join

    # Export parameters for registered ROIs
    print("Please enter the following parameters in order, separated by a space:")
    print("WindowSize_(px) StacksBeforeCenter_(slices) StacksAfterCenter_(slices)")
    print("e.g. 512 3 1")
    print("N.B. The number of stacks before/after the central one can be 0")
    while True:  # Checking validity of input
        try:
            ws, sb, sa = [int(i) for i in input().split()]
        except ValueError:
            pass
        else:
            # The padding/slicing below needs a positive window and non-negative slice counts
            if ws > 0 and sb >= 0 and sa >= 0:
                break
        print("\nOne or more parameters is not valid, please insert them again...")

    # Only coordinate tuples are used; imported files may carry "ROI Lost" markers
    roi_coords = [[c for c in co if not isinstance(c, str)] for co in coords]
    half = int(ws / 2)
    n_z, n_y, n_x = mov.shape[1], mov.shape[3], mov.shape[4]

    # Extracting ROIs
    rois = [[] for _ in roi_coords]
    for frame in range(mov.shape[0]):

        # First coordinate of each ROI at this frame; ROIs absent from the frame are skipped
        present = []
        for i, co in enumerate(roi_coords):
            at_frame = [c for c in co if c[0] == frame]
            if at_frame:
                present.append((i, at_frame[0]))
        if not present:
            continue

        # Zero-padding the frame (sb/sa slices in z, ws pixels in y and x) so that ROIs
        # near the borders are not clipped; built once per frame rather than once per ROI
        expanded = numpy.zeros((n_z + sb + sa, mov.shape[2], n_y + 2 * ws, n_x + 2 * ws), mov.dtype)
        expanded[sb : sb + n_z, :, ws : ws + n_y, ws : ws + n_x] = mov[frame]

        # A movie position (z, y, x) sits at (z + sb, y + ws, x + ws) in the padded frame
        for i, c in present:
            roi = numpy.copy(expanded[c[1] : c[1] + sb + sa + 1, :, c[2] + half : c[2] + half + ws, c[3] + half : c[3] + half + ws])
            rois[i].append(roi)

    # Saving ROIs
    folder = file[:-4] + "_ROIs"
    if not isdir(folder):
        mkdir(folder)
    for i, (co, roi) in enumerate(zip(roi_coords, rois)):
        if not roi:  # No coordinate of this ROI falls inside the movie
            continue
        frames = [c[0] for c in co]
        start_frame, stop_frame = str(min(frames) + 1), str(max(frames) + 1)
        # os.path.join instead of a hard-coded backslash, which only separates folders on Windows
        save_name = join(folder, "ROI-" + str(i + 1) + "_T=" + start_frame + "-" + stop_frame + ".tif")
        # imwrite is tifffile's current writer (imsave is deprecated)
        tifffile.imwrite(save_name, numpy.array(roi), imagej = True)

###MAIN
# Interactive entry point. Third-party dependencies are imported here, at module level,
# so that every function defined above can use them as globals.

print("\nLoading dependencies...")

try:
    import numpy
    from scipy.ndimage import correlate
    from scipy.ndimage import median_filter as medfil
    from scipy.ndimage import maximum_filter as maxfil
    from scipy.ndimage import minimum_filter as minfil
    from skimage.feature import match_template
    import tifffile
    import time as tm
    from os import mkdir
    from os.path import isdir
except ImportError:
    print("One or more dependencies are not installed.\nAlso, make sure your terminal has been activated")
    raise SystemExit(1)  # Non-zero status: the script cannot run without its dependencies


def _formatElapsed(seconds):
    """Format a duration as whole minutes plus remaining seconds, e.g. 100.0 -> 1' 40.0"."""
    minutes, remainder = divmod(seconds, 60)
    # Minutes must be floored: round(seconds / 60) printed 100 s as 2' 40"
    return str(int(minutes)) + "' " + str(round(remainder, 2)) + '"'


def _normalizeChannel(movie, rc):
    """Return the normalized frames of channel rc and the matching foreground threshold.

    Steps:
    1. The channel is rescaled to 8 bit (factor 255 / channel max).
    2. Every frame is shifted by the mean of the frame means and divided by the mean of
       the frame standard deviations.
    3. The threshold is the Otsu threshold of the 8-bit data, in the same normalized units.

    Frames are converted one at a time, which is less memory intensive than converting
    the whole array at once.

    Args:
        movie: array indexed as [frame, z, channel, y, x].
        rc: 0-based channel index.

    Returns:
        (frames, threshold): a list of float [z, y, x] arrays (one per frame) and a float.
    """
    print("\nProceeding with data normalization (This may take a few minutes)")
    conversion_factor = 255 / numpy.max(movie[:, :, rc])
    frames = list(movie[:, :, rc])
    for m in range(len(frames)):
        frames[m] = numpy.uint8(frames[m] * conversion_factor)

    mean = numpy.mean([numpy.mean(m) for m in frames])
    stdev = numpy.mean([numpy.std(m) for m in frames])

    threshold = (otsuThreshold(numpy.uint8(numpy.array(frames))) - mean) / stdev
    for m in range(len(frames)):
        frames[m] = (frames[m].astype(float) - mean) / stdev

    return frames, threshold


def _sortCoordinates(coordinates):
    """Sort each ROI's entries by frame, moving a "ROI Lost" marker (if any) to the end."""
    return [sorted(co) if "ROI Lost" not in co else sorted([c for c in co if c != "ROI Lost"]) + ["ROI Lost"] for co in coordinates]


key = ""
options_switch = 0  # Used to determine if options have already been displayed

# Running script until the user decides to stop
while key != "6":

    # Displaying running modes
    if options_switch == 0:
        options_switch = 1
        displayOptions()
    key = input()
    print("\n")

    # Import movie
    if key == "1":
        options_switch = 0
        movie, file_name = movieImporter()
        movie_shape = movie.shape

    # Finding regions of interest and registering movie
    elif key == "2":
        options_switch = 0

        # Checking that movie is loaded in its original form
        try:
            movie
        except NameError:
            print("Error, no image is loaded...")
            continue
        if isinstance(movie, list):  # Reloading movie if previously normalized
            print("Reloading movie...")
            movie = tifffile.imread(file_name)
            movie = movie.reshape(movie_shape)

        # Import parameters for registration
        print("Please enter the following parameters in order, separated by a space:")
        print("RegistrationChannel WindowSize_(px) MaxDistance_(px) MinCorrelationScore_(0-1) ProjectionMode_(Focus_or_Max)")
        print("e.g. 1 256 500 0.5 Focus")
        while True:  # Checking validity of input
            try:
                rc, ws, md, mc, projection_mode = input().lower().split()
                rc, ws, md, mc = int(rc) - 1, int(ws), int(md), float(mc)
            except ValueError:
                pass
            else:
                # The channel is 1-based on input and must exist. Windows need a positive
                # size; the search distance cannot be negative.
                if projection_mode in ["focus", "max"] and 0 <= rc < movie_shape[2] and ws > 0 and md >= 0:
                    break
            print("\nOne or more parameters is not valid, please insert them again...")

        # Calculating run time
        elapsed_time = tm.time()

        # Converting the registration channel to 8 bit, then normalizing it
        movie, threshold = _normalizeChannel(movie, rc)

        # XY registration is done on a maximum intensity or focused projection of the movie
        mov_projection = makeProjection(movie, projection_mode)

        # Finding starting regions of interest
        print("Finding initial regions of interest\n")
        coordinates, frame = [], 0
        while frame < mov_projection.shape[0] and not len(coordinates):
            coordinates = subdivideFrame(mov_projection[frame] >= threshold, ws, frame)
            frame += 1
        if not len(coordinates):
            print("\nNo good points found, try option 3")
            continue

        # Forward pass: XY-registering known ROIs and finding new ones
        for frame in range(mov_projection.shape[0] - 1):
            print("Registering frames " + str(frame + 1) + " and " + str(frame + 2))
            coordinates = compareTimePoints(mov_projection[frame], mov_projection[frame + 1], coordinates, threshold, ws, md, mc, frame, frame + 1, 1)

        # Backwards pass: XY-registering newly found ROIs backwards in time
        for frame in range(mov_projection.shape[0] - 1, 0, -1):
            print("Registering frames " + str(frame + 1) + " and " + str(frame))
            coordinates = compareTimePoints(mov_projection[frame], mov_projection[frame - 1], coordinates, threshold, ws, md, mc, frame, frame - 1, 0)

        # Sorting coordinates based on movie frame
        coordinates = _sortCoordinates(coordinates)

        # Refining registration along the z axis by defining the center slice as the most in-focus one
        refined = []
        for i, coords in enumerate(coordinates):
            print("\nRefining coordinates of ROI " + str(i + 1) + " along the z axis")
            refined.append(refineCoordinates(movie, coords, ws))
        coordinates = refined

        # Exporting coordinates
        exportCoordinates(file_name, coordinates)
        elapsed_time = tm.time() - elapsed_time

        # Printing number of found ROIs and run time
        print("\nTask complete! Found " + str(len(coordinates)) + " regions of interest")
        print("Time elapsed: " + _formatElapsed(elapsed_time) + "\n")

    # Registering user-defined regions of interest
    elif key == "3":
        options_switch = 0

        # Checking that movie is loaded in its original form
        try:
            movie
        except NameError:
            print("Error, no image is loaded...")
            continue
        if isinstance(movie, list):  # Reloading movie if previously normalized
            print("Reloading movie...")
            movie = tifffile.imread(file_name)
            movie = movie.reshape(movie_shape)

        # Defining points of interest via user input
        rc, ws, md, mc, coordinates, projection_mode = inputPointsOfInterest()

        # Calculating run time
        elapsed_time = tm.time()

        # Converting the registration channel to 8 bit, then normalizing it
        movie, threshold = _normalizeChannel(movie, rc)

        # Registration is initially done on a maximum intensity or focused projection of the movie
        mov_projection = makeProjection(movie, projection_mode)
        print("\n")

        # Forward pass: XY-registering ROIs forward in time
        for frame in range(min([c[0][0] for c in coordinates]), mov_projection.shape[0] - 1):
            print("Registering frames " + str(frame + 1) + " and " + str(frame + 2))
            coordinates = compareTimePoints(mov_projection[frame], mov_projection[frame + 1], coordinates, threshold, ws, md, mc, frame, frame + 1, 0)

        # Backwards pass: XY-registering ROIs backward in time
        for frame in range(mov_projection.shape[0] - 1, 0, -1):
            print("Registering frames " + str(frame + 1) + " and " + str(frame))
            coordinates = compareTimePoints(mov_projection[frame], mov_projection[frame - 1], coordinates, threshold, ws, md, mc, frame, frame - 1, 0)

        # Sorting coordinates based on time
        coordinates = _sortCoordinates(coordinates)

        # Refining registration along the z axis
        print("\nRefining coordinates along the z axis")
        coordinates = [refineCoordinates(movie, coords, ws) for coords in coordinates]

        # Exporting coordinates
        exportCoordinates(file_name, coordinates)
        elapsed_time = tm.time() - elapsed_time

        # Printing run time
        print("\nTask complete!")
        print("Time elapsed: " + _formatElapsed(elapsed_time) + "\n")

    # Import .tsv file with ROIs coordinates
    elif key == "4":
        options_switch = 0
        coordinates = importCoordinates()

    # Export ROIs
    elif key == "5":
        options_switch = 0

        # Checking that movie is loaded and that ROIs coordinates are defined
        try:
            movie
        except NameError:
            print("Error, no image is loaded...")
            continue
        try:
            coordinates
        except NameError:
            print("Error, no ROIs found...")
            continue
        if isinstance(movie, list):  # Reloading the original movie if option 2 or 3 normalized it
            print("Reloading movie...")
            movie = tifffile.imread(file_name)
            movie = movie.reshape(movie_shape)

        # Export ROIs
        exportROIs(movie, coordinates, file_name)

    # Terminating script
    elif key == "6":
        print("Terminating script...\n")

    # Error message for unrecognized running mode
    else:
        print("Unrecognized key, please try again...\n")